import torch

from data import prompt
from model import model

def chat_with_image_test(image, prompt, device="cuda"):
    """Run one multimodal chat inference pass over a single image/prompt pair.

    Args:
        image: Input image in whatever format ``MultiModalModel.forward``
            expects (presumably a PIL image or tensor — confirm against the
            model implementation).
        prompt: User prompt text handed to ``prepare_chat_data``.
        device: Torch device string the model is moved to (default ``"cuda"``).

    Returns:
        The decoded responses from ``tokenizer.batch_decode`` — a list with
        one string per batch item (here, a single-element batch).
    """
    # BUG FIX: the ``prompt`` parameter shadows the ``data.prompt`` module
    # imported at the top of the file, so the original
    # ``prompt.prepare_chat_data(prompt)`` called the method on the caller's
    # argument instead of the module. Re-import the module under an alias.
    from data import prompt as prompt_module

    messages = prompt_module.prepare_chat_data(prompt)

    multi_modal_model = model.MultiModalModel().to(device)
    multi_modal_model.eval()  # inference: disable dropout / train-mode layers
    with torch.no_grad():  # no gradients needed for a forward-only test
        vision_hidden_states, img_token_position, llm_outputs = multi_modal_model(
            [messages], [image]
        )

    # Greedy per-position decode of a single forward pass.
    # NOTE(review): this is NOT autoregressive generation — it takes the
    # argmax over the logits of one pass; confirm that matches the intent
    # of this test helper.
    generated_ids = llm_outputs.logits.argmax(dim=-1)
    response = multi_modal_model.tokenizer.batch_decode(
        generated_ids, skip_special_tokens=True
    )
    return response